import os
import sys
from dataclasses import dataclass

# Derive every path from this script's own location (not hard-coded), and
# treat the directory three levels above it as the project root.
script_file_path = __file__
script_dir_path = os.path.split(script_file_path)[0]  # same as os.path.dirname
project_dir_path = os.path.join(script_dir_path, "../../..")

# Expose the project root to the import system so the project-local
# `scripts.training...` imports that follow can be resolved. The entry is
# added at the END of sys.path, so anything already on the path keeps
# precedence.
sys.path.extend([project_dir_path])

from scripts.training.launching import common_mixture_cli
from scripts.training.training import CommonTrainingHparams
from scripts.training.synergy.model import Llama3SynergyTrainingModel, Llama3SynergyTrainingModelHparams


@dataclass(kw_only=True)
class Llama3SynergyTrainingHparams(CommonTrainingHparams):
    """Hyperparameters for training ``Llama3SynergyTrainingModel``.

    Extends ``CommonTrainingHparams`` with a ``model`` section typed as
    ``Llama3SynergyTrainingModelHparams``. This class is handed to
    ``common_mixture_cli`` as ``hparams_cls`` by this script's entry point.
    Presumably the CLI builds an instance from the ``hparams.json`` file next
    to this script; confirm in ``common_mixture_cli``. If
    ``CommonTrainingHparams`` already declares ``model``, this redeclaration
    narrows its type; verify against the base class.
    """

    # kw_only=True makes the fields declared in this class keyword-only in
    # the generated __init__. As a result, this required (no-default) field
    # can coexist with any defaulted fields inherited from
    # CommonTrainingHparams without a dataclass field-ordering error.
    model: Llama3SynergyTrainingModelHparams


if __name__ == "__main__":
    # Delegate to the shared mixture-training CLI. It receives this script's
    # hyperparameter dataclass and the model class used to build the model.
    # It also receives three paths:
    #   * script_file_path: this script.
    #   * hparams_file_path: hparams.json located beside this script.
    #   * trainings_dir_path: <project root>/trainings, presumably where
    #     training runs are stored (confirm in common_mixture_cli).
    common_mixture_cli(
        hparams_cls=Llama3SynergyTrainingHparams,
        model_factory=Llama3SynergyTrainingModel,
        script_file_path=script_file_path,
        hparams_file_path=os.path.join(script_dir_path, "hparams.json"),
        trainings_dir_path=os.path.join(project_dir_path, "trainings"),
    )
